from datasets import load_dataset, concatenate_datasets
from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
from distilabel.steps import (KeepColumns, FormatTextGenerationSFT)
import shutil
import os
import pandas as pd



# Delete the on-disk cache directory named after this pipeline, presumably so
# that a new run does not pick up results left behind by an earlier one.
# NOTE: this executes every time the module is loaded, not only under the
# `__main__` guard at the bottom of the file.
pipeline_cache = (
    '/root/.cache/distilabel/pipelines/distill-qwen-32b-r1-hitom'
)
if os.path.exists(pipeline_cache):
    shutil.rmtree(pipeline_cache)


# Jinja-style prompt with `story`, `question` and `choices` placeholders.
# The model is asked to reason step by step and to wrap its final answer in
# \boxed{} (the doubled backslash below yields a single literal backslash).
prompt_template = (
    "You will be given a story, a question, and the corresponding choices. "
    "Please reason step by step to answer this question, "
    "and put your final answer within \\boxed{}:\n"
    "\n"
    "Story: {{ story }}\n"
    "Question: {{question}}\n"
    "Choices: {{choices}}\n"
)


# Hi-ToM is sampled in six separate windows of the train split, each window
# spanning 60 indices, and the windows are then stacked into one dataset in
# the order listed here.
hitom_path = ".../ToM_data/Hi-ToM"
hitom_windows = (
    "train[:60]",
    "train[100:160]",
    "train[200:260]",
    "train[600:660]",
    "train[700:760]",
    "train[800:860]",
)

# One load per window, evaluated in order; the individual slices stay
# available under their numbered names.
dataset1, dataset2, dataset3, dataset4, dataset5, dataset6 = (
    load_dataset(hitom_path, split=window) for window in hitom_windows
)

dataset = concatenate_datasets(
    [dataset1, dataset2, dataset3, dataset4, dataset5, dataset6]
)



def add_combined_column(dataset):
    """Return ``dataset.map(...)`` with an ``entire_instruction`` field added.

    Each example must provide ``story``, ``question`` and ``choices``. The new
    field is the single-line text
    ``"Story: <story> Question: <question> Choices: <choices>"``.

    Args:
        dataset: Any object exposing ``map(function)`` that calls ``function``
            on each example (a mapping supporting item assignment).

    Returns:
        Whatever ``dataset.map`` returns for the per-example transformation.
    """

    def _with_instruction(row):
        options = row["choices"]
        # A list of options is flattened into one space-separated string;
        # any other value is interpolated into the text unchanged.
        rendered = " ".join(options) if isinstance(options, list) else options

        row["entire_instruction"] = (
            f"Story: {row['story']} Question: {row['question']} Choices: {rendered}"
        )
        return row

    return dataset.map(_with_instruction)

# Attach the `entire_instruction` column, then print the dataset summary
# followed by its first row as a quick sanity check.
dataset = add_combined_column(dataset)
for preview in (dataset, dataset[0]):
    print(preview)




model_id = ".../distill32B"

# Settings passed to vLLM as `extra_kwargs`.
vllm_engine_kwargs = {
    "tensor_parallel_size": 1,
    "max_model_len": 16384,
}

# Sampling settings used for every generation.
# NOTE(review): `max_new_tokens` equals `max_model_len`, so a prompt plus a
# full-length completion cannot both fit in the context window - confirm this
# is intended.
sampling_kwargs = {
    "temperature": 0.7,
    "max_new_tokens": 16384,
}

# Steps must be instantiated inside the `with` block so they are registered
# on this pipeline.
with Pipeline(
    name="distill-qwen-32b-r1-hitom",
    description="A pipeline to generate data from a distilled r1 model",
) as pipeline:

    # The same path serves as both model weights and tokenizer.
    llm = vLLM(
        model=model_id,
        tokenizer=model_id,
        extra_kwargs=vllm_engine_kwargs,
        generation_kwargs=sampling_kwargs,
    )

    # Fills `prompt_template` from the three listed columns and requests two
    # generations per input, feeding the model four rows per batch.
    text_generation = TextGeneration(
        llm=llm,
        template=prompt_template,
        num_generations=2,
        input_batch_size=4,
        columns=["story", "question", "choices"],
    )

    # The SFT formatter takes its `instruction` input from the
    # `entire_instruction` column rather than from the rendered prompt.
    format_sft = FormatTextGenerationSFT(
        input_mappings={"instruction": "entire_instruction"},
    )

    # Wire generation output into the SFT formatter.
    text_generation >> format_sft
   



if __name__ == "__main__":
    # Directory the generated distiset is written to and read back from.
    output_dir = ".../SFTData/HiToM_May_test"

    # Run the pipeline over the prepared dataset and show a summary plus the
    # first generated training row.
    distiset = pipeline.run(dataset=dataset)
    print(distiset)
    print(distiset['default']['train'][0])

    distiset.save_to_disk(output_dir)

    # `load_from_disk` hands back the loaded distiset rather than updating the
    # object it is called on, so its return value must be kept; otherwise the
    # final print would only show the in-memory object again instead of what
    # was actually written to disk.
    distiset = type(distiset).load_from_disk(output_dir)
    print(distiset)
   